package com.hanxiaozhang.tree.binarytreerecursion;

import com.hanxiaozhang.tree.BinaryTreeUtil;
import com.hanxiaozhang.tree.Node;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;

/**
 * Lowest common ancestor (LCA) of two nodes in a binary tree.<br>
 * Given the head node {@code head} of a binary tree and two other nodes {@code o1} and
 * {@code o2}, return the lowest common ancestor of {@code o1} and {@code o2}.
 * <p>
 * Two implementations are provided and cross-checked by {@link #main(String[])}:
 * <ul>
 *   <li>{@link #lowestAncestor1}: build a child-to-parent map, then intersect the ancestor chains;</li>
 *   <li>{@link #lowestAncestor2}: binary-tree recursion that returns an {@link Info} per subtree.</li>
 * </ul>
 *
 * @author hanxinghua
 * @create 2021/10/30
 * @since 1.0.0
 */
public class lowestAncestor {

    /**
     * Randomized cross-check: repeatedly builds a random tree, picks two random nodes from it and
     * prints "Oops!" whenever the two implementations disagree.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int maxLevel = 4;
        int maxValue = 100;
        int testTimes = 1000000;
        for (int i = 0; i < testTimes; i++) {
            Node head = generateRandomBST(maxLevel, maxValue);
            Node o1 = pickRandomOne(head);
            Node o2 = pickRandomOne(head);
            // Reference comparison on purpose: both methods must return the very same node.
            if (lowestAncestor1(head, o1, o2) != lowestAncestor2(head, o1, o2)) {
                System.out.println("Oops!");
            }
        }
        System.out.println("finish!");
    }


    /**
     * Method 1: parent map plus ancestor set.
     * <p>
     * Records every node's parent, collects all ancestors of {@code o1} (including {@code o1}
     * itself) into a set, then walks up from {@code o2} until it hits a node in that set.
     *
     * @param head root of the tree; may be {@code null}
     * @param o1   first node
     * @param o2   second node
     * @return the lowest common ancestor of {@code o1} and {@code o2}, or {@code null} when the tree
     *         is empty or either node is not part of it (the same contract as
     *         {@link #lowestAncestor2})
     */
    public static Node lowestAncestor1(Node head, Node o1, Node o2) {
        if (head == null) {
            return null;
        }
        // key -> its parent; the head's parent is null
        HashMap<Node, Node> parentMap = new HashMap<>();
        parentMap.put(head, null);
        fillParentMap(head, parentMap);
        // Every node of the tree is a key of parentMap. Without this guard, an absent o1/o2 makes
        // the upward walk from o2 reach null and spin forever (parentMap.get(null) is null).
        if (!parentMap.containsKey(o1) || !parentMap.containsKey(o2)) {
            return null;
        }
        // All ancestors of o1, o1 included, up to and including the head
        HashSet<Node> o1Set = new HashSet<>();
        for (Node cur = o1; cur != null; cur = parentMap.get(cur)) {
            o1Set.add(cur);
        }
        // The first node on o2's ancestor chain that is also on o1's chain is the LCA.
        // Both nodes are in the tree, so this walk always stops at the latest at the head.
        Node cur = o2;
        while (!o1Set.contains(cur)) {
            cur = parentMap.get(cur);
        }
        return cur;
    }

    /**
     * Recursively records the parent of every node below {@code head} in {@code parentMap}.
     * The entry for {@code head} itself is the caller's responsibility.
     *
     * @param head      root of the subtree to walk; a {@code null} subtree is a no-op
     * @param parentMap child -> parent map to fill
     */
    public static void fillParentMap(Node head, HashMap<Node, Node> parentMap) {
        if (head == null) {
            return;
        }
        if (head.left != null) {
            parentMap.put(head.left, head);
            fillParentMap(head.left, parentMap);
        }
        if (head.right != null) {
            parentMap.put(head.right, head);
            fillParentMap(head.right, parentMap);
        }
    }

    /**
     * Method 2: binary-tree recursion.
     *
     * @param head root of the tree; may be {@code null}
     * @param o1   first node
     * @param o2   second node
     * @return the lowest common ancestor of {@code o1} and {@code o2}, or {@code null} when the tree
     *         is empty or either node is not part of it
     */
    public static Node lowestAncestor2(Node head, Node o1, Node o2) {
        return process(head, o1, o2).ans;
    }

    /**
     * Recursive step for {@link #lowestAncestor2}: summarizes the subtree rooted at {@code head}.
     *
     * @param head root of the current subtree; may be {@code null}
     * @param o1   first node
     * @param o2   second node
     * @return whether {@code o1} / {@code o2} occur in this subtree, and their meeting point if
     *         both do
     */
    public static Info process(Node head, Node o1, Node o2) {
        if (head == null) {
            return new Info(null, false, false);
        }
        Info leftInfo = process(head.left, o1, o2);
        Info rightInfo = process(head.right, o1, o2);
        // A node is found in this subtree if head is that node or either child subtree contains it.
        boolean findO1 = head == o1 || leftInfo.findO1 || rightInfo.findO1;
        boolean findO2 = head == o2 || leftInfo.findO2 || rightInfo.findO2;
        // Where do o1 and o2 first meet?
        // 1) already inside the left subtree  -> reuse the left answer
        // 2) already inside the right subtree -> reuse the right answer
        //    (in a tree a node lives in only one subtree, so 1) and 2) never both hold)
        // 3) not below, but both are found at this level -> head is the meeting point
        // 4) otherwise they have not met yet -> null
        Node ans;
        if (leftInfo.ans != null) {
            ans = leftInfo.ans;
        } else if (rightInfo.ans != null) {
            ans = rightInfo.ans;
        } else if (findO1 && findO2) {
            ans = head;
        } else {
            ans = null;
        }
        return new Info(ans, findO1, findO2);
    }


    /**
     * Per-subtree result of {@link #process}.
     */
    public static class Info {
        /**
         * Node where o1 and o2 first meet in this subtree, or null if they have not met yet.
         */
        public Node ans;

        /**
         * Whether o1 was found in this subtree.
         */
        public boolean findO1;

        /**
         * Whether o2 was found in this subtree.
         */
        public boolean findO2;

        public Info(Node a, boolean f1, boolean f2) {
            ans = a;
            findO1 = f1;
            findO2 = f2;
        }
    }

    // for test: delegates to BinaryTreeUtil.generate(1, maxLevel, maxValue).
    // NOTE(review): despite the name, whether the result is actually a BST depends on
    // BinaryTreeUtil.generate — confirm.
    public static Node generateRandomBST(int maxLevel, int maxValue) {
        return BinaryTreeUtil.generate(1, maxLevel, maxValue);
    }


    // for test: returns a uniformly random node of the tree, or null for an empty tree
    public static Node pickRandomOne(Node head) {
        if (head == null) {
            return null;
        }
        ArrayList<Node> arr = new ArrayList<>();
        fillPrelist(head, arr);
        int randomIndex = (int) (Math.random() * arr.size());
        return arr.get(randomIndex);
    }

    // for test: appends the nodes of the tree to arr in pre-order (head, left, right)
    public static void fillPrelist(Node head, ArrayList<Node> arr) {
        if (head == null) {
            return;
        }
        arr.add(head);
        fillPrelist(head.left, arr);
        fillPrelist(head.right, arr);
    }


}
